{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "executionInfo": {
     "elapsed": 3944,
     "status": "ok",
     "timestamp": 1616888724581,
     "user": {
      "displayName": "Apurva Bhargava",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gi3hHA-32IVQPOzXK40Itcc5oZmMDf0Vsnw_e_afg=s64",
      "userId": "07288249218888651888"
     },
     "user_tz": 240
    },
    "id": "nh8rENT0TsTx"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "E:\\Dropbox\\Optimal Training Sets\\Replication File v4\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import torchvision\n",
    "from torch import nn\n",
    "from torch.autograd import Variable\n",
    "from torch.utils.data import DataLoader\n",
    "import torch.nn.functional as F\n",
    "import asyncio\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import random\n",
    "import math\n",
    "import os\n",
    "\n",
    "# Set the working directory: the data/output/ paths used below are relative to the parent folder\n",
    "os.chdir(\"..\")\n",
    "print(os.getcwd())\n",
    "\n",
    "# Seed every RNG in use. random.seed alone does not affect torch, so without\n",
    "# torch.manual_seed the Xavier weight initialisation differs on every run.\n",
    "SEED = 10012\n",
    "torch.manual_seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "random.seed(SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "executionInfo": {
     "elapsed": 1348,
     "status": "ok",
     "timestamp": 1616888740888,
     "user": {
      "displayName": "Apurva Bhargava",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gi3hHA-32IVQPOzXK40Itcc5oZmMDf0Vsnw_e_afg=s64",
      "userId": "07288249218888651888"
     },
     "user_tz": 240
    },
    "id": "5iHOeMN9DOFO"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cpu\n"
     ]
    }
   ],
   "source": [
    "# Run on the first CUDA GPU when one is available, otherwise on the CPU.\n",
    "device = torch.device(\"cuda:0\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
    "print(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Datasets to process, and the embedding types to compute reconstruction indices for.\n",
    "dataset_names = ['stwts', 'eo']\n",
    "\n",
    "embed_types = [\n",
    "    'cvec_pca16', 'cvec_nmf16', 'cvec_umap16', 'cvec_tsne16',\n",
    "    'bert', 'roberta', 'distil', 'glove6B',\n",
    "    'universal', 'lda100',\n",
    "]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "executionInfo": {
     "elapsed": 2777,
     "status": "ok",
     "timestamp": 1616888754228,
     "user": {
      "displayName": "Apurva Bhargava",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gi3hHA-32IVQPOzXK40Itcc5oZmMDf0Vsnw_e_afg=s64",
      "userId": "07288249218888651888"
     },
     "user_tz": 240
    },
    "id": "tRkZXpzsgdb-"
   },
   "outputs": [],
   "source": [
    "class SparseModel(nn.Module):\n",
    "    \"\"\"Linear self-reconstruction model: ``forward(X) = X @ W``.\n",
    "\n",
    "    ``W`` is a trainable ``(hdim, hdim)`` matrix initialised with Xavier-normal\n",
    "    values, so ``hdim`` must equal the size of the input's last dimension.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, hdim):\n",
    "        super().__init__()\n",
    "        self.hdim = hdim\n",
    "        init_w = torch.zeros((hdim, hdim))\n",
    "        nn.init.xavier_normal_(init_w)\n",
    "        # nn.Parameter registers the tensor with requires_grad=True.\n",
    "        self.weights = nn.Parameter(init_w)\n",
    "\n",
    "    def forward(self, input):\n",
    "        return input @ self.weights\n",
    "\n",
    "\n",
    "def custom_loss(output, input):\n",
    "    \"\"\"Return 0.5 * ||input - output||_F^2 divided by the size of dim 1 of ``input``.\"\"\"\n",
    "    residual_norm = torch.norm(input - output, p='fro')\n",
    "    return 0.5 * torch.square(residual_norm) / input.size()[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "executionInfo": {
     "elapsed": 236,
     "status": "ok",
     "timestamp": 1616888815387,
     "user": {
      "displayName": "Apurva Bhargava",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gi3hHA-32IVQPOzXK40Itcc5oZmMDf0Vsnw_e_afg=s64",
      "userId": "07288249218888651888"
     },
     "user_tz": 240
    },
    "id": "EGpskROSILlg"
   },
   "outputs": [],
   "source": [
    "# Optimisation hyperparameters shared by every dataset / embedding type.\n",
    "learning_rate = 0.005\n",
    "num_epochs = 20000  # upper bound; the training loop stops early once the eval loss rises\n",
    "lmbda = 0.01  # weight of the L2,1 (row-norm) penalty on the weight matrix\n",
    "\n",
    "def Diff(li1, li2):\n",
    "    \"\"\"Return the symmetric difference of two iterables as a list.\n",
    "\n",
    "    Items only in ``li1`` come first, followed by items only in ``li2``.\n",
    "    \"\"\"\n",
    "    only_first, only_second = set(li1), set(li2)\n",
    "    return list(only_first - only_second) + list(only_second - only_first)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 236133,
     "status": "ok",
     "timestamp": 1616889052393,
     "user": {
      "displayName": "Apurva Bhargava",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gi3hHA-32IVQPOzXK40Itcc5oZmMDf0Vsnw_e_afg=s64",
      "userId": "07288249218888651888"
     },
     "user_tz": 240
    },
    "id": "i2qh4x3WHO4Q",
    "outputId": "b9ce0e8f-b980-4d58-bde4-66f49532ed25",
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating: stwts cvec_pca16 iter1\n",
      "torch.Size([16, 5012])\n",
      "Generating: stwts cvec_pca16 iter2\n",
      "torch.Size([16, 5010])\n",
      "Generating: stwts cvec_pca16 iter3\n",
      "torch.Size([16, 5010])\n",
      "Generating: stwts cvec_pca16 iter4\n",
      "torch.Size([16, 5010])\n",
      "Generating: stwts cvec_pca16 iter5\n",
      "torch.Size([16, 5010])\n",
      "Generating: stwts cvec_pca16 iter6\n",
      "torch.Size([16, 5012])\n",
      "Generating: stwts cvec_pca16 iter7\n",
      "torch.Size([16, 5010])\n",
      "Generating: stwts cvec_pca16 iter8\n",
      "torch.Size([16, 5010])\n",
      "Generating: stwts cvec_pca16 iter9\n",
      "torch.Size([16, 5010])\n",
      "Generating: stwts cvec_pca16 iter10\n",
      "torch.Size([16, 5010])\n",
      "Generating: stwts cvec_nmf16 iter1\n",
      "torch.Size([16, 5012])\n",
      "Generating: stwts cvec_nmf16 iter2\n",
      "torch.Size([16, 5010])\n",
      "Generating: stwts cvec_nmf16 iter3\n",
      "torch.Size([16, 5010])\n",
      "Generating: stwts cvec_nmf16 iter4\n",
      "torch.Size([16, 5010])\n",
      "Generating: stwts cvec_nmf16 iter5\n",
      "torch.Size([16, 5010])\n",
      "Generating: stwts cvec_nmf16 iter6\n",
      "torch.Size([16, 5012])\n",
      "Generating: stwts cvec_nmf16 iter7\n",
      "torch.Size([16, 5010])\n",
      "Generating: stwts cvec_nmf16 iter8\n",
      "torch.Size([16, 5010])\n",
      "Generating: stwts cvec_nmf16 iter9\n",
      "torch.Size([16, 5010])\n",
      "Generating: stwts cvec_nmf16 iter10\n",
      "torch.Size([16, 5010])\n",
      "Generating: stwts cvec_umap16 iter1\n",
      "torch.Size([16, 5012])\n",
      "Generating: stwts cvec_umap16 iter2\n",
      "torch.Size([16, 5010])\n",
      "Generating: stwts cvec_umap16 iter3\n",
      "torch.Size([16, 5010])\n",
      "Generating: stwts cvec_umap16 iter4\n",
      "torch.Size([16, 5010])\n",
      "Generating: stwts cvec_umap16 iter5\n",
      "torch.Size([16, 5010])\n",
      "Generating: stwts cvec_umap16 iter6\n",
      "torch.Size([16, 5012])\n",
      "Generating: stwts cvec_umap16 iter7\n",
      "torch.Size([16, 5010])\n",
      "Generating: stwts cvec_umap16 iter8\n",
      "torch.Size([16, 5010])\n",
      "Generating: stwts cvec_umap16 iter9\n",
      "torch.Size([16, 5010])\n",
      "Generating: stwts cvec_umap16 iter10\n",
      "torch.Size([16, 5010])\n",
      "Generating: stwts cvec_tsne16 iter1\n",
      "torch.Size([2, 5012])\n",
      "Generating: stwts cvec_tsne16 iter2\n",
      "torch.Size([2, 5010])\n",
      "Generating: stwts cvec_tsne16 iter3\n",
      "torch.Size([2, 5010])\n",
      "Generating: stwts cvec_tsne16 iter4\n",
      "torch.Size([2, 5010])\n",
      "Generating: stwts cvec_tsne16 iter5\n",
      "torch.Size([2, 5010])\n",
      "Generating: stwts cvec_tsne16 iter6\n",
      "torch.Size([2, 5012])\n",
      "Generating: stwts cvec_tsne16 iter7\n",
      "torch.Size([2, 5010])\n",
      "Generating: stwts cvec_tsne16 iter8\n",
      "torch.Size([2, 5010])\n",
      "Generating: stwts cvec_tsne16 iter9\n",
      "torch.Size([2, 5010])\n",
      "Generating: stwts cvec_tsne16 iter10\n",
      "torch.Size([2, 5010])\n",
      "Generating: stwts bert iter1\n",
      "torch.Size([768, 5012])\n",
      "Generating: stwts bert iter2\n",
      "torch.Size([768, 5010])\n",
      "Generating: stwts bert iter3\n",
      "torch.Size([768, 5010])\n",
      "Generating: stwts bert iter4\n",
      "torch.Size([768, 5010])\n",
      "Generating: stwts bert iter5\n",
      "torch.Size([768, 5010])\n",
      "Generating: stwts bert iter6\n",
      "torch.Size([768, 5012])\n",
      "Generating: stwts bert iter7\n",
      "torch.Size([768, 5010])\n",
      "Generating: stwts bert iter8\n",
      "torch.Size([768, 5010])\n",
      "Generating: stwts bert iter9\n",
      "torch.Size([768, 5010])\n",
      "Generating: stwts bert iter10\n",
      "torch.Size([768, 5010])\n",
      "Generating: stwts roberta iter1\n",
      "torch.Size([768, 5012])\n",
      "Generating: stwts roberta iter2\n",
      "torch.Size([768, 5010])\n",
      "Generating: stwts roberta iter3\n",
      "torch.Size([768, 5010])\n",
      "Generating: stwts roberta iter4\n",
      "torch.Size([768, 5010])\n",
      "Generating: stwts roberta iter5\n",
      "torch.Size([768, 5010])\n",
      "Generating: stwts roberta iter6\n",
      "torch.Size([768, 5012])\n",
      "Generating: stwts roberta iter7\n",
      "torch.Size([768, 5010])\n",
      "Generating: stwts roberta iter8\n",
      "torch.Size([768, 5010])\n",
      "Generating: stwts roberta iter9\n",
      "torch.Size([768, 5010])\n",
      "Generating: stwts roberta iter10\n",
      "torch.Size([768, 5010])\n",
      "Generating: stwts distil iter1\n",
      "torch.Size([768, 5012])\n",
      "Generating: stwts distil iter2\n",
      "torch.Size([768, 5010])\n",
      "Generating: stwts distil iter3\n",
      "torch.Size([768, 5010])\n",
      "Generating: stwts distil iter4\n",
      "torch.Size([768, 5010])\n",
      "Generating: stwts distil iter5\n",
      "torch.Size([768, 5010])\n",
      "Generating: stwts distil iter6\n",
      "torch.Size([768, 5012])\n",
      "Generating: stwts distil iter7\n",
      "torch.Size([768, 5010])\n",
      "Generating: stwts distil iter8\n",
      "torch.Size([768, 5010])\n",
      "Generating: stwts distil iter9\n",
      "torch.Size([768, 5010])\n",
      "Generating: stwts distil iter10\n",
      "torch.Size([768, 5010])\n",
      "Generating: stwts glove6B iter1\n",
      "torch.Size([300, 5012])\n",
      "Generating: stwts glove6B iter2\n",
      "torch.Size([300, 5010])\n",
      "Generating: stwts glove6B iter3\n",
      "torch.Size([300, 5010])\n",
      "Generating: stwts glove6B iter4\n",
      "torch.Size([300, 5010])\n",
      "Generating: stwts glove6B iter5\n",
      "torch.Size([300, 5010])\n",
      "Generating: stwts glove6B iter6\n",
      "torch.Size([300, 5012])\n",
      "Generating: stwts glove6B iter7\n",
      "torch.Size([300, 5010])\n",
      "Generating: stwts glove6B iter8\n",
      "torch.Size([300, 5010])\n",
      "Generating: stwts glove6B iter9\n",
      "torch.Size([300, 5010])\n",
      "Generating: stwts glove6B iter10\n",
      "torch.Size([300, 5010])\n",
      "Generating: stwts universal iter1\n",
      "torch.Size([512, 5012])\n",
      "Generating: stwts universal iter2\n",
      "torch.Size([512, 5010])\n",
      "Generating: stwts universal iter3\n",
      "torch.Size([512, 5010])\n",
      "Generating: stwts universal iter4\n",
      "torch.Size([512, 5010])\n",
      "Generating: stwts universal iter5\n",
      "torch.Size([512, 5010])\n",
      "Generating: stwts universal iter6\n",
      "torch.Size([512, 5012])\n",
      "Generating: stwts universal iter7\n",
      "torch.Size([512, 5010])\n",
      "Generating: stwts universal iter8\n",
      "torch.Size([512, 5010])\n",
      "Generating: stwts universal iter9\n",
      "torch.Size([512, 5010])\n",
      "Generating: stwts universal iter10\n",
      "torch.Size([512, 5010])\n",
      "Generating: stwts lda100 iter1\n",
      "torch.Size([100, 5012])\n",
      "Generating: stwts lda100 iter2\n",
      "torch.Size([100, 5010])\n",
      "Generating: stwts lda100 iter3\n",
      "torch.Size([100, 5010])\n",
      "Generating: stwts lda100 iter4\n",
      "torch.Size([100, 5010])\n",
      "Generating: stwts lda100 iter5\n",
      "torch.Size([100, 5010])\n",
      "Generating: stwts lda100 iter6\n",
      "torch.Size([100, 5012])\n",
      "Generating: stwts lda100 iter7\n",
      "torch.Size([100, 5010])\n",
      "Generating: stwts lda100 iter8\n",
      "torch.Size([100, 5010])\n",
      "Generating: stwts lda100 iter9\n",
      "torch.Size([100, 5010])\n",
      "Generating: stwts lda100 iter10\n",
      "torch.Size([100, 5010])\n",
      "Generating: eo cvec_pca16 iter1\n",
      "torch.Size([16, 8173])\n",
      "Generating: eo cvec_pca16 iter2\n",
      "torch.Size([16, 8173])\n",
      "Generating: eo cvec_pca16 iter3\n",
      "torch.Size([16, 8173])\n",
      "Generating: eo cvec_pca16 iter4\n",
      "torch.Size([16, 8175])\n",
      "Generating: eo cvec_pca16 iter5\n",
      "torch.Size([16, 8173])\n",
      "Generating: eo cvec_pca16 iter6\n",
      "torch.Size([16, 8173])\n",
      "Generating: eo cvec_pca16 iter7\n",
      "torch.Size([16, 8173])\n",
      "Generating: eo cvec_pca16 iter8\n",
      "torch.Size([16, 8173])\n",
      "Generating: eo cvec_pca16 iter9\n",
      "torch.Size([16, 8173])\n",
      "Generating: eo cvec_pca16 iter10\n",
      "torch.Size([16, 8173])\n",
      "Generating: eo cvec_nmf16 iter1\n",
      "torch.Size([16, 8173])\n",
      "Generating: eo cvec_nmf16 iter2\n",
      "torch.Size([16, 8173])\n",
      "Generating: eo cvec_nmf16 iter3\n",
      "torch.Size([16, 8173])\n",
      "Generating: eo cvec_nmf16 iter4\n",
      "torch.Size([16, 8175])\n",
      "Generating: eo cvec_nmf16 iter5\n",
      "torch.Size([16, 8173])\n",
      "Generating: eo cvec_nmf16 iter6\n",
      "torch.Size([16, 8173])\n",
      "Generating: eo cvec_nmf16 iter7\n",
      "torch.Size([16, 8173])\n",
      "Generating: eo cvec_nmf16 iter8\n",
      "torch.Size([16, 8173])\n",
      "Generating: eo cvec_nmf16 iter9\n",
      "torch.Size([16, 8173])\n",
      "Generating: eo cvec_nmf16 iter10\n",
      "torch.Size([16, 8173])\n",
      "Generating: eo cvec_umap16 iter1\n",
      "torch.Size([16, 8173])\n",
      "Generating: eo cvec_umap16 iter2\n",
      "torch.Size([16, 8173])\n",
      "Generating: eo cvec_umap16 iter3\n",
      "torch.Size([16, 8173])\n",
      "Generating: eo cvec_umap16 iter4\n",
      "torch.Size([16, 8175])\n",
      "Generating: eo cvec_umap16 iter5\n",
      "torch.Size([16, 8173])\n",
      "Generating: eo cvec_umap16 iter6\n",
      "torch.Size([16, 8173])\n",
      "Generating: eo cvec_umap16 iter7\n",
      "torch.Size([16, 8173])\n",
      "Generating: eo cvec_umap16 iter8\n",
      "torch.Size([16, 8173])\n",
      "Generating: eo cvec_umap16 iter9\n",
      "torch.Size([16, 8173])\n",
      "Generating: eo cvec_umap16 iter10\n",
      "torch.Size([16, 8173])\n",
      "Generating: eo cvec_tsne16 iter1\n",
      "torch.Size([2, 8173])\n",
      "Generating: eo cvec_tsne16 iter2\n",
      "torch.Size([2, 8173])\n",
      "Generating: eo cvec_tsne16 iter3\n",
      "torch.Size([2, 8173])\n",
      "Generating: eo cvec_tsne16 iter4\n",
      "torch.Size([2, 8175])\n",
      "Generating: eo cvec_tsne16 iter5\n",
      "torch.Size([2, 8173])\n",
      "Generating: eo cvec_tsne16 iter6\n",
      "torch.Size([2, 8173])\n",
      "Generating: eo cvec_tsne16 iter7\n",
      "torch.Size([2, 8173])\n",
      "Generating: eo cvec_tsne16 iter8\n",
      "torch.Size([2, 8173])\n",
      "Generating: eo cvec_tsne16 iter9\n",
      "torch.Size([2, 8173])\n",
      "Generating: eo cvec_tsne16 iter10\n",
      "torch.Size([2, 8173])\n",
      "Generating: eo bert iter1\n",
      "torch.Size([768, 8173])\n",
      "Generating: eo bert iter2\n",
      "torch.Size([768, 8173])\n",
      "Generating: eo bert iter3\n",
      "torch.Size([768, 8173])\n",
      "Generating: eo bert iter4\n",
      "torch.Size([768, 8175])\n",
      "Generating: eo bert iter5\n",
      "torch.Size([768, 8173])\n",
      "Generating: eo bert iter6\n",
      "torch.Size([768, 8173])\n",
      "Generating: eo bert iter7\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([768, 8173])\n",
      "Generating: eo bert iter8\n",
      "torch.Size([768, 8173])\n",
      "Generating: eo bert iter9\n",
      "torch.Size([768, 8173])\n",
      "Generating: eo bert iter10\n",
      "torch.Size([768, 8173])\n",
      "Generating: eo roberta iter1\n",
      "torch.Size([768, 8173])\n",
      "Generating: eo roberta iter2\n",
      "torch.Size([768, 8173])\n",
      "Generating: eo roberta iter3\n",
      "torch.Size([768, 8173])\n",
      "Generating: eo roberta iter4\n",
      "torch.Size([768, 8175])\n",
      "Generating: eo roberta iter5\n",
      "torch.Size([768, 8173])\n",
      "Generating: eo roberta iter6\n",
      "torch.Size([768, 8173])\n",
      "Generating: eo roberta iter7\n",
      "torch.Size([768, 8173])\n",
      "Generating: eo roberta iter8\n",
      "torch.Size([768, 8173])\n",
      "Generating: eo roberta iter9\n",
      "torch.Size([768, 8173])\n",
      "Generating: eo roberta iter10\n",
      "torch.Size([768, 8173])\n",
      "Generating: eo distil iter1\n",
      "torch.Size([768, 8173])\n",
      "Generating: eo distil iter2\n",
      "torch.Size([768, 8173])\n",
      "Generating: eo distil iter3\n",
      "torch.Size([768, 8173])\n",
      "Generating: eo distil iter4\n",
      "torch.Size([768, 8175])\n",
      "Generating: eo distil iter5\n",
      "torch.Size([768, 8173])\n",
      "Generating: eo distil iter6\n",
      "torch.Size([768, 8173])\n",
      "Generating: eo distil iter7\n",
      "torch.Size([768, 8173])\n",
      "Generating: eo distil iter8\n",
      "torch.Size([768, 8173])\n",
      "Generating: eo distil iter9\n",
      "torch.Size([768, 8173])\n",
      "Generating: eo distil iter10\n",
      "torch.Size([768, 8173])\n",
      "Generating: eo glove6B iter1\n",
      "torch.Size([300, 8173])\n",
      "Generating: eo glove6B iter2\n",
      "torch.Size([300, 8173])\n",
      "Generating: eo glove6B iter3\n",
      "torch.Size([300, 8173])\n",
      "Generating: eo glove6B iter4\n",
      "torch.Size([300, 8175])\n",
      "Generating: eo glove6B iter5\n",
      "torch.Size([300, 8173])\n",
      "Generating: eo glove6B iter6\n",
      "torch.Size([300, 8173])\n",
      "Generating: eo glove6B iter7\n",
      "torch.Size([300, 8173])\n",
      "Generating: eo glove6B iter8\n",
      "torch.Size([300, 8173])\n",
      "Generating: eo glove6B iter9\n",
      "torch.Size([300, 8173])\n",
      "Generating: eo glove6B iter10\n",
      "torch.Size([300, 8173])\n",
      "Generating: eo universal iter1\n",
      "torch.Size([512, 8173])\n",
      "Generating: eo universal iter2\n",
      "torch.Size([512, 8173])\n",
      "Generating: eo universal iter3\n",
      "torch.Size([512, 8173])\n",
      "Generating: eo universal iter4\n",
      "torch.Size([512, 8175])\n",
      "Generating: eo universal iter5\n",
      "torch.Size([512, 8173])\n",
      "Generating: eo universal iter6\n",
      "torch.Size([512, 8173])\n",
      "Generating: eo universal iter7\n",
      "torch.Size([512, 8173])\n",
      "Generating: eo universal iter8\n",
      "torch.Size([512, 8173])\n",
      "Generating: eo universal iter9\n",
      "torch.Size([512, 8173])\n",
      "Generating: eo universal iter10\n",
      "torch.Size([512, 8173])\n",
      "Generating: eo lda100 iter1\n",
      "torch.Size([100, 8173])\n",
      "Generating: eo lda100 iter2\n",
      "torch.Size([100, 8173])\n",
      "Generating: eo lda100 iter3\n",
      "torch.Size([100, 8173])\n",
      "Generating: eo lda100 iter4\n",
      "torch.Size([100, 8175])\n",
      "Generating: eo lda100 iter5\n",
      "torch.Size([100, 8173])\n",
      "Generating: eo lda100 iter6\n",
      "torch.Size([100, 8173])\n",
      "Generating: eo lda100 iter7\n",
      "torch.Size([100, 8173])\n",
      "Generating: eo lda100 iter8\n",
      "torch.Size([100, 8173])\n",
      "Generating: eo lda100 iter9\n",
      "torch.Size([100, 8173])\n",
      "Generating: eo lda100 iter10\n",
      "torch.Size([100, 8173])\n"
     ]
    }
   ],
   "source": [
    "TOP_K = 3000  # number of highest-norm weight rows whose indices are saved per run\n",
    "\n",
    "for dataset_name in dataset_names:\n",
    "    testset_list = pd.read_csv('data/output/' + dataset_name + '_testset_list.csv')\n",
    "    for embed_type in embed_types:\n",
    "        # The embedding matrix is the same for every split, so read it once per (dataset, embedding).\n",
    "        full_df = pd.read_csv('data/output/' + dataset_name + '_' + embed_type + '_full.csv', index_col=0)\n",
    "        for i in range(testset_list.shape[1]):\n",
    "            iter_name = 'iter' + str(i + 1)\n",
    "            print('Generating: ' + dataset_name + ' ' + embed_type + ' ' + iter_name)\n",
    "            # Training rows = every row position that is NOT a test position for this split.\n",
    "            # (A symmetric difference such as Diff(range(1, n), idx_test) would always drop\n",
    "            # row 0 and would put any test index outside range(1, n) into the training set.)\n",
    "            idx_test = set(testset_list[iter_name].dropna().astype(int))\n",
    "            idx_train = [r for r in range(len(full_df)) if r not in idx_test]\n",
    "            # Transpose so columns are training rows: shape (n_csv_columns, n_train).\n",
    "            data = torch.tensor(full_df.iloc[idx_train].values).T.to(device)\n",
    "            print(data.shape)\n",
    "            # float32 copy for the model, made once rather than on every epoch; the\n",
    "            # reconstruction loss still compares against the original-precision `data`.\n",
    "            data_f32 = data.float()\n",
    "            # The model must live on the same device as the data (matters when device is CUDA).\n",
    "            model = SparseModel(data.size()[1]).to(device)\n",
    "            optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)\n",
    "            loss_list = []\n",
    "            prev_eval_loss = float('inf')\n",
    "            for epoch in range(num_epochs):\n",
    "                model.train()\n",
    "                # ===================forward=====================\n",
    "                output = model(data_f32)\n",
    "                # L2,1 penalty: sum of the L2 norms of the rows of W (encourages whole rows to shrink to zero).\n",
    "                reg21 = torch.sum(torch.norm(model.weights, p=2, dim=1))\n",
    "                loss = custom_loss(output, data) + lmbda * reg21\n",
    "                # ===================backward====================\n",
    "                optimizer.zero_grad()\n",
    "                loss.backward()\n",
    "                optimizer.step()\n",
    "                # ===================early stopping==============\n",
    "                model.eval()\n",
    "                with torch.no_grad():  # evaluation only; no autograd graph needed\n",
    "                    eval_loss = F.l1_loss(model(data_f32), data_f32).item()\n",
    "                loss_list.append(eval_loss)\n",
    "                # Stop at the first epoch where the L1 reconstruction error increases.\n",
    "                if eval_loss > prev_eval_loss:\n",
    "                    break\n",
    "                prev_eval_loss = eval_loss\n",
    "            mw = model.weights.cpu().detach()\n",
    "            # Rank rows of W by L2 norm. NOTE(review): the saved values are positions within\n",
    "            # the training matrix (i.e. indices into idx_train), not row positions of the full CSV.\n",
    "            top_rows = torch.argsort(torch.norm(mw, p=2, dim=1), descending=True).numpy()[:TOP_K]\n",
    "            np.save('data/output/' + 'indices_' + dataset_name + '_' + embed_type + '_recon_iter' + str(i + 1), top_rows)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "executionInfo": {
     "elapsed": 374,
     "status": "ok",
     "timestamp": 1616889222499,
     "user": {
      "displayName": "Apurva Bhargava",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gi3hHA-32IVQPOzXK40Itcc5oZmMDf0Vsnw_e_afg=s64",
      "userId": "07288249218888651888"
     },
     "user_tz": 240
    },
    "id": "DTHSaVtmurD5"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'## Parallel test\\n\\ndef background(f):\\n    def wrapped(*args, **kwargs):\\n        return asyncio.get_event_loop().run_in_executor(None, f, *args, **kwargs)\\n    return wrapped\\n\\n\\n@background\\ndef recon_fun(q,i,j):\\n    dataset_name = dataset_names[q]\\n    testset_list = pd.read_csv(dataset_name+\\'_testset_list.csv\\')\\n    print(\"Generating: \"+ dataset_names[q] + \" \" + embed_types[j] + \" iter\" + str(i+1))\\n    data = pd.read_csv(dataset_name + \\'_\\' + embed_types[j] + \\'_full.csv\\', index_col=0)\\n    idx_test = testset_list[\"iter\"+str(i+1)]\\n    idx_train = Diff(range(1, len(data)), idx_test)\\n    data = data.iloc[idx_train]\\n    data = torch.tensor(data.values)\\n    data = data.T\\n    print(data.shape)\\n    data = data.to(device)\\n    #max_data = torch.max(data)\\n    #min_data = torch.min(data)\\n    #data = (data - min_data) / (max_data - min_data)\\n    model = SparseModel(data.size()[1])\\n    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)\\n    loss_list = []\\n    prev_eval_loss = 99999\\n    for epoch in range(num_epochs):\\n      model.train()\\n      # ===================forward=====================\\n      output = model(data.float())\\n      reg21 = torch.sum(torch.norm(model.weights, p=2, dim=1))\\n      loss = custom_loss(output, data) + lmbda * reg21\\n      #print(\\'L1 norm of L2 norm of weights: \\', reg21.item())\\n      # ===================backward====================\\n      optimizer.zero_grad()\\n      loss.backward()\\n      optimizer.step()\\n      # ===================log=======================\\n      #print(\\'epoch [{}/{}], loss:{:.4f}\\'.format(epoch + 1, num_epochs, loss.item()))\\n      model.eval()\\n      eval_loss = F.l1_loss(model(data.float()), data.float()).item()\\n      loss_list.append(eval_loss)\\n      #print(\\'Eval loss: \\', eval_loss)\\n      if eval_loss > prev_eval_loss:\\n        break\\n      prev_eval_loss = eval_loss\\n    mw = 
model.weights.cpu().detach()\\n    np.save(\\'indices_\\'+ dataset_name +\\'_\\' + embed_types[j] + \\'_recon_iter\\' + str(i+1), torch.argsort(torch.norm(mw, p=2, dim=1), descending=True).numpy()[:3000])\\n    #torch.save(model.state_dict(), \\'./sim_autoencoder.pth\\')\\n\\n@background    \\ndef test_fun(q,i,j):\\n    print(\\'indices_\\'+ dataset_names[q] +\\'_\\' + embed_types[j] + \\'_recon_iter\\' + str(i+1))\\n\\n'"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# NOTE(review): disabled prototype of a parallel version of the training loop, kept as an\n",
    "# inert string literal. It is not kept in sync with the loop in the earlier cell: its CSV\n",
    "# paths lack the 'data/output/' prefix, it builds the training split with the symmetric\n",
    "# difference Diff(range(1, len(data)), idx_test), and it never moves the model to `device`.\n",
    "'''## Parallel test\n",
    "\n",
    "def background(f):\n",
    "    def wrapped(*args, **kwargs):\n",
    "        return asyncio.get_event_loop().run_in_executor(None, f, *args, **kwargs)\n",
    "    return wrapped\n",
    "\n",
    "\n",
    "@background\n",
    "def recon_fun(q,i,j):\n",
    "    dataset_name = dataset_names[q]\n",
    "    testset_list = pd.read_csv(dataset_name+'_testset_list.csv')\n",
    "    print(\"Generating: \"+ dataset_names[q] + \" \" + embed_types[j] + \" iter\" + str(i+1))\n",
    "    data = pd.read_csv(dataset_name + '_' + embed_types[j] + '_full.csv', index_col=0)\n",
    "    idx_test = testset_list[\"iter\"+str(i+1)]\n",
    "    idx_train = Diff(range(1, len(data)), idx_test)\n",
    "    data = data.iloc[idx_train]\n",
    "    data = torch.tensor(data.values)\n",
    "    data = data.T\n",
    "    print(data.shape)\n",
    "    data = data.to(device)\n",
    "    #max_data = torch.max(data)\n",
    "    #min_data = torch.min(data)\n",
    "    #data = (data - min_data) / (max_data - min_data)\n",
    "    model = SparseModel(data.size()[1])\n",
    "    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)\n",
    "    loss_list = []\n",
    "    prev_eval_loss = 99999\n",
    "    for epoch in range(num_epochs):\n",
    "      model.train()\n",
    "      # ===================forward=====================\n",
    "      output = model(data.float())\n",
    "      reg21 = torch.sum(torch.norm(model.weights, p=2, dim=1))\n",
    "      loss = custom_loss(output, data) + lmbda * reg21\n",
    "      #print('L1 norm of L2 norm of weights: ', reg21.item())\n",
    "      # ===================backward====================\n",
    "      optimizer.zero_grad()\n",
    "      loss.backward()\n",
    "      optimizer.step()\n",
    "      # ===================log=======================\n",
    "      #print('epoch [{}/{}], loss:{:.4f}'.format(epoch + 1, num_epochs, loss.item()))\n",
    "      model.eval()\n",
    "      eval_loss = F.l1_loss(model(data.float()), data.float()).item()\n",
    "      loss_list.append(eval_loss)\n",
    "      #print('Eval loss: ', eval_loss)\n",
    "      if eval_loss > prev_eval_loss:\n",
    "        break\n",
    "      prev_eval_loss = eval_loss\n",
    "    mw = model.weights.cpu().detach()\n",
    "    np.save('indices_'+ dataset_name +'_' + embed_types[j] + '_recon_iter' + str(i+1), torch.argsort(torch.norm(mw, p=2, dim=1), descending=True).numpy()[:3000])\n",
    "    #torch.save(model.state_dict(), './sim_autoencoder.pth')\n",
    "\n",
    "@background    \n",
    "def test_fun(q,i,j):\n",
    "    print('indices_'+ dataset_names[q] +'_' + embed_types[j] + '_recon_iter' + str(i+1))\n",
    "\n",
    "'''"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'for q in range(len(dataset_names)):\\n    for i in range(10):\\n        for j in range(len(embed_types)):\\n            recon_fun(q,i,j)\\n\\n'"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# NOTE(review): disabled driver for the parallel prototype. `recon_fun` is only defined\n",
    "# inside the disabled string of the previous cell, so enabling this cell alone would raise\n",
    "# NameError; it also hard-codes 10 splits via range(10).\n",
    "'''for q in range(len(dataset_names)):\n",
    "    for i in range(10):\n",
    "        for j in range(len(embed_types)):\n",
    "            recon_fun(q,i,j)\n",
    "\n",
    "'''"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "KA-loUZNhrLv"
   },
   "source": [
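    "Learning-rate / $\\lambda$ values recorded per embedding type (`lda100` has no entry). Note: the training loop above does not read this table; it applies the single global `learning_rate = 0.005` and `lmbda = 0.01` to every embedding type.\n",
    "\n",
    "| Embedding type | learning rate | lambda |\n",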
    "| --- | --- | --- |\n",
    "| cvec_tsne16 | 0.0005 | 0.1 |\n",
    "| cvec_pca16 | 0.0005 | 0.1 |\n",
    "| cvec_umap16 | 0.0005 | 0.1 |\n",
    "| cvec_nmf16 | 0.005 | 0.01 |\n",
    "| bert | 0.01 | 0.01 |\n",
    "| distil | 0.01 | 0.01 |\n",
    "| roberta | 0.01 | 0.01 |\n",
    "| glove6B | 0.01 | 0.01 |\n",
    "| universal | 0.05 | 0.001 |"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'import itertools\\n\\n# Get the combinations of elements (like expand.grid)\\nlist1 = [range(len(dataset_names)), range(10), range(len(embed_types))]\\ncombinations = [p for p in itertools.product(*list1)]\\ncombos2 = [list(ele) for ele in combinations]\\ncombos3 = pd.DataFrame(combos2, columns = [\"q\", \"i\", \"j\"])\\n\\nprint(combos3)\\n\\n# Set up the multicore\\nimport parmap\\nimport multiprocessing as mp\\npool = mp.Pool(mp.cpu_count()-4)\\n\\nresults = parmap.starmap(recon_fun, zip(combos3[\\'q\\'].tolist(), combos3[\\'i\\'].tolist(), combos3[\\'j\\'].tolist()))\\n\\npool.close()'"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# NOTE(review): disabled multiprocessing driver. As written, `pool` is created but never\n",
    "# passed to parmap.starmap, so it is unused; mp.cpu_count()-4 is <= 0 on machines with\n",
    "# 4 or fewer cores, which mp.Pool rejects; and `recon_fun` only exists inside a disabled string.\n",
    "'''import itertools\n",
    "\n",
    "# Get the combinations of elements (like expand.grid)\n",
    "list1 = [range(len(dataset_names)), range(10), range(len(embed_types))]\n",
    "combinations = [p for p in itertools.product(*list1)]\n",
    "combos2 = [list(ele) for ele in combinations]\n",
    "combos3 = pd.DataFrame(combos2, columns = [\"q\", \"i\", \"j\"])\n",
    "\n",
    "print(combos3)\n",
    "\n",
    "# Set up the multicore\n",
    "import parmap\n",
    "import multiprocessing as mp\n",
    "pool = mp.Pool(mp.cpu_count()-4)\n",
    "\n",
    "results = parmap.starmap(recon_fun, zip(combos3['q'].tolist(), combos3['i'].tolist(), combos3['j'].tolist()))\n",
    "\n",
    "pool.close()'''"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "name": "ReconstructionLoss.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
